import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.base import clone
from sklearn.metrics import (
    roc_auc_score, accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, balanced_accuracy_score
)
from sklearn.model_selection import StratifiedKFold

def compute_specificity(y_true, y_pred):
    """Return the specificity, tn / (tn + fp), of binary predictions.

    The confusion matrix of ``y_true`` versus ``y_pred`` is flattened and
    unpacked as ``tn, fp, fn, tp``, so the inputs must produce a 2x2 matrix.
    When ``tn + fp`` is zero the function returns 0.0 instead of dividing.
    """
    true_neg, false_pos, _false_neg, _true_pos = confusion_matrix(y_true, y_pred).ravel()
    negatives = true_neg + false_pos
    if negatives > 0:
        return true_neg / negatives
    return 0.0

def _evaluate_classifier_cv(clf, X, y, splitter, class_labels):
    """Cross-validate one classifier and return its per-fold metric lists.

    Args:
        clf: Estimator used as a template; a fresh clone is fitted per fold.
        X: DataFrame with the feature columns to train on.
        y: Series with the binary target, aligned with ``X``.
        splitter: Object whose ``split(X, y)`` yields positional
            (train_index, test_index) pairs.
        class_labels: The two class labels in sorted order; the second one
            is treated as the positive class by the confusion matrix.

    Returns:
        dict mapping each metric name to a list with one value per fold.
    """
    fold_metrics = {
        'accuracy': [], 'precision': [], 'recall': [], 'f1': [],
        'balanced_accuracy': [], 'roc_auc': [],
        'tp': [], 'tn': [], 'fp': [], 'fn': [], 'specificity': []
    }

    for train_index, test_index in splitter.split(X, y):
        X_train, X_test = X.iloc[train_index], X.iloc[test_index]
        y_train, y_test = y.iloc[train_index], y.iloc[test_index]

        # Fit a fresh copy in every fold: the caller's estimator is left
        # untouched and no state (e.g. warm_start weights) can carry over
        # from a previous fold. safe=False falls back to a deep copy for
        # objects that are not scikit-learn estimators.
        model = clone(clf, safe=False)
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        # Continuous scores for ROC AUC: positive-class probabilities when
        # available, otherwise the raw decision-function values.
        if hasattr(model, "predict_proba"):
            y_score = model.predict_proba(X_test)[:, 1]
        elif hasattr(model, "decision_function"):
            y_score = model.decision_function(X_test)
        else:
            y_score = None

        # Explicit labels keep the matrix 2x2 even if this fold happens to
        # contain (and predict) a single class.
        tn, fp, fn, tp = confusion_matrix(y_test, y_pred, labels=class_labels).ravel()

        # ROC AUC is undefined without scores or with a single-class fold.
        if y_score is not None and len(np.unique(y_test)) > 1:
            roc_auc = roc_auc_score(y_test, y_score)
        else:
            roc_auc = np.nan

        fold_metrics['accuracy'].append(accuracy_score(y_test, y_pred))
        fold_metrics['precision'].append(precision_score(y_test, y_pred, zero_division=0))
        fold_metrics['recall'].append(recall_score(y_test, y_pred, zero_division=0))
        fold_metrics['f1'].append(f1_score(y_test, y_pred, zero_division=0))
        fold_metrics['balanced_accuracy'].append(balanced_accuracy_score(y_test, y_pred))
        fold_metrics['roc_auc'].append(roc_auc)
        fold_metrics['tp'].append(tp)
        fold_metrics['tn'].append(tn)
        fold_metrics['fp'].append(fp)
        fold_metrics['fn'].append(fn)
        # Specificity from the cells already computed above; 0.0 when the
        # denominator is zero.
        fold_metrics['specificity'].append(tn / (tn + fp) if (tn + fp) > 0 else 0.0)

    return fold_metrics


def _save_metric_barplot(plot_df, metrica, output_dir):
    """Save a bar plot of the mean value of one metric for every model."""
    plt.figure(figsize=(10, 6))

    # Vertical bars: x = model label, y = mean metric value.
    sns.barplot(x='Metodo_Clf', y=metrica + "_val", data=plot_df, color='C0')
    plt.xticks(rotation=90)
    plt.ylabel(metrica.upper())
    plt.xlabel("Metodo di Feature Selection + Classificatore")
    plt.title(f"Performance media ({metrica.upper()}) dopo CV esterna")
    plt.tight_layout()

    plot_path = os.path.join(output_dir, f"{metrica}_cvest_barplot.png")
    plt.savefig(plot_path)
    plt.close()


def external_crossval_with_fs(train_path, val_path, test_path, models_dict,
                               selected_features_dict, output_dir, n_splits=5, random_state=42):
    """Run stratified K-fold CV for every feature-selection/classifier pair.

    The three CSV files are concatenated into one dataset whose target is
    the ``'Label'`` column. For each feature-selection method in
    ``models_dict`` the matching columns from ``selected_features_dict`` are
    kept, and each classifier is evaluated with a shuffled
    ``StratifiedKFold``. Per-fold metrics are summarised as
    ``"mean ± std"`` strings, written to
    ``external_crossval_results.xlsx`` in ``output_dir``, and bar plots of
    the mean ROC AUC and balanced accuracy are saved alongside it.

    NOTE(review): the feature lists are supplied precomputed; if they were
    selected using rows that land in the test folds here, the scores will be
    optimistically biased — confirm with the caller.

    Args:
        train_path: Path of the training CSV.
        val_path: Path of the validation CSV.
        test_path: Path of the test CSV.
        models_dict: Mapping ``{fs_method: {clf_name: estimator}}``.
        selected_features_dict: Mapping ``{fs_method: list of columns}``.
        output_dir: Directory for the Excel file and plots; created if it
            does not exist.
        n_splits: Number of CV folds.
        random_state: Seed for the fold shuffling.

    Returns:
        DataFrame with one row per (feature selection, classifier) pair and
        one ``"mean ± std"`` string column per metric.

    Raises:
        ValueError: If the ``'Label'`` column does not hold exactly two
            distinct classes.
    """
    start_time = time.perf_counter()

    # Create the output directory up front so a missing folder cannot
    # discard the results after all the training work is done.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Step 1: rebuild the complete dataset from the three splits.
    full_df = pd.concat(
        [pd.read_csv(path) for path in (train_path, val_path, test_path)],
        ignore_index=True,
    )

    # The target is expected in the 'Label' column.
    X_all = full_df.drop(columns=['Label'])
    y_all = full_df['Label']

    # Every metric below is binary; fail with a clear message otherwise.
    class_labels = np.unique(y_all)
    if len(class_labels) != 2:
        raise ValueError(
            f"Expected exactly 2 classes in 'Label', found {len(class_labels)}"
        )

    # With a fixed integer seed the splitter yields the same folds on every
    # split() call, so a single instance serves all models.
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    results = []       # formatted "mean ± std" rows (saved and returned)
    mean_records = []  # numeric means, kept separately for plotting

    for fs_method, classifiers in models_dict.items():
        X_selected = X_all[selected_features_dict[fs_method]]

        for clf_name, clf in classifiers.items():
            fold_metrics = _evaluate_classifier_cv(clf, X_selected, y_all, skf, class_labels)

            result_row = {
                'Feature_Selection': fs_method,
                'Classifier': clf_name
            }
            mean_row = {'Metodo_Clf': f"{fs_method} - {clf_name}"}

            # Combine mean and std into a single cell per metric.
            for metric, values in fold_metrics.items():
                mean = np.mean(values)
                std = np.std(values)
                result_row[metric] = f"{mean:.3f} ± {std:.3f}"
                mean_row[metric + "_val"] = mean

            results.append(result_row)
            mean_records.append(mean_row)

    # Save the summary table to Excel.
    results_df = pd.DataFrame(results)
    output_path = os.path.join(output_dir, "external_crossval_results.xlsx")
    results_df.to_excel(output_path, index=False)

    # Bar plots of the mean metrics (skipped when nothing was evaluated).
    if mean_records:
        sns.set_theme(style="whitegrid")
        plot_df = pd.DataFrame(mean_records)
        for metrica in ('roc_auc', 'balanced_accuracy'):
            _save_metric_barplot(plot_df, metrica, output_dir)

    elapsed = time.perf_counter() - start_time
    print(f"Tempo di esecuzione: {elapsed:.2f} secondi")
    print(f"Risultati salvati in: {output_path}")
    return results_df
